# encoding: utf-8

from model.model import Transformer
from model.LMConfig import LMConfig

# Build the model from a default-constructed LMConfig. No arguments are passed,
# so whatever defaults LMConfig defines are used.
# NOTE(review): the model is instantiated at import time; if this file is ever
# imported rather than run as a script, consider an `if __name__ == "__main__":`
# guard around this and the reporting code below.
config = LMConfig()
model = Transformer(config)


def count_parameters(_model):
    """Return the total element count of the trainable parameters of *_model*.

    Iterates ``_model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is true; frozen parameters are
    skipped. Returns 0 when there are no trainable parameters.

    Presumably *_model* is a ``torch.nn.Module`` (TODO confirm), but any
    object whose ``parameters()`` yields items with ``numel()`` and
    ``requires_grad`` works.
    """
    trainable_total = 0
    for tensor in _model.parameters():
        if not tensor.requires_grad:
            # Frozen weights are not counted.
            continue
        trainable_total += tensor.numel()
    return trainable_total


# Count the trainable parameters once and reuse the result for both units,
# instead of walking every parameter again for each figure.
params = count_parameters(model)
# Parameter counts are conventionally reported in decimal units
# (1 M = 10**6, 1 B = 10**9); dividing by 1024**2 / 1024**3 understated them.
print(f'LLM总参数量：{params / 1e6:.3f} M, {params / 1e9:.3f} B')

